{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "TAG Network Selection NeurIPS Submission.ipynb",
      "provenance": [
        {
          "file_id": "1dCSaXl7r2cQrYQynjjtdwIKaq1huEDkw",
          "timestamp": 1622695458414
        },
        {
          "file_id": "1ZAsAeHeZwqS-dV_GrIimYepJMZiaNz3G",
          "timestamp": 1618806418512
        },
        {
          "file_id": "1S4fmk97iAXQ_WrpWxIgPLbF1R8bx_MQU",
          "timestamp": 1611821985296
        }
      ],
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      }
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SqZbhX1XYzlj"
      },
      "source": [
        "### TAG Network Selection for...\n",
        "##Efficiently Identifying Task Groupings for Multi-Task Learning\n",
        "\n",
        "Licensed under the Apache License, Version 2.0\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0pxANRw6W_FB"
      },
      "source": [
        "### CelebA"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "O12pDbuVLBgX"
      },
      "source": [
        "# TAG average inter-task affinities.\n",
        "revised_integrals = {'5_o_Clock_Shadow': {'5_o_Clock_Shadow': 239.72068958020347, 'Black_Hair': -83.33548677967012, 'Blond_Hair': -86.0517490934864, 'Brown_Hair': -82.42068049883574, 'Goatee': -81.97442569877438, 'Mustache': -86.02419971908775, 'No_Beard': -49.037597943401195, 'Rosy_Cheeks': -87.3711143500207, 'Wearing_Hat': -86.77026014855653}, 'Black_Hair': {'5_o_Clock_Shadow': -188.15657486387036, 'Black_Hair': 378.8511885198165, 'Blond_Hair': -184.52892836169266, 'Brown_Hair': -147.50232270522932, 'Goatee': -189.464933472838, 'Mustache': -190.9978380658358, 'No_Beard': -185.28495779578083, 'Rosy_Cheeks': -190.6576875894878, 'Wearing_Hat': -189.41933750548824}, 'Blond_Hair': {'5_o_Clock_Shadow': -121.05393759410462, 'Black_Hair': -109.58362502552134, 'Blond_Hair': 310.7120317369794, 'Brown_Hair': -106.02204254857973, 'Goatee': -122.62230998931582, 'Mustache': -124.40291642298871, 'No_Beard': -117.20407488470592, 'Rosy_Cheeks': -120.55632616819761, 'Wearing_Hat': -123.25851616996826}, 'Brown_Hair': {'5_o_Clock_Shadow': -143.72124825753016, 'Black_Hair': -112.17930410506702, 'Blond_Hair': -139.84807280830086, 'Brown_Hair': 321.53891335979023, 'Goatee': -145.42740587751805, 'Mustache': -146.78348353101595, 'No_Beard': -142.3971373230815, 'Rosy_Cheeks': -146.29884678475298, 'Wearing_Hat': -146.22726918591053}, 'Goatee': {'5_o_Clock_Shadow': -85.53234687179813, 'Black_Hair': -92.21160616332925, 'Blond_Hair': -95.75043274633258, 'Brown_Hair': -91.74112249406916, 'Goatee': 238.42495844470608, 'Mustache': -65.42735927466887, 'No_Beard': -45.95103369426986, 'Rosy_Cheeks': -98.59883005339663, 'Wearing_Hat': -95.6828851878095}, 'Mustache': {'5_o_Clock_Shadow': -30.223281466637463, 'Black_Hair': -32.697193679549336, 'Blond_Hair': -34.78380316865219, 'Brown_Hair': -33.47520370724056, 'Goatee': -0.8396637863075016, 'Mustache': 178.66788767657505, 'No_Beard': -15.271524701599517, 'Rosy_Cheeks': -35.20658096923006, 'Wearing_Hat': -34.37035408412552}, 'No_Beard': {'5_o_Clock_Shadow': -109.47881478244865, 'Black_Hair': -147.50179136978863, 'Blond_Hair': -152.39464713301385, 'Brown_Hair': -148.48881711332825, 'Goatee': -127.31550988605389, 'Mustache': -147.13584310859744, 'No_Beard': 291.8495619980659, 'Rosy_Cheeks': -157.47676146469695, 'Wearing_Hat': -156.15812124227463}, 'Rosy_Cheeks': {'5_o_Clock_Shadow': -74.97883839974561, 'Black_Hair': -74.24941175274341, 'Blond_Hair': -72.04941829851619, 'Brown_Hair': -73.83159368156338, 'Goatee': -76.45193929878833, 'Mustache': -76.49830443422142, 'No_Beard': -75.21144504121774, 'Rosy_Cheeks': 240.390333377952, 'Wearing_Hat': -77.00375527569419}, 'Wearing_Hat': {'5_o_Clock_Shadow': -109.37079668739855, 'Black_Hair': -100.84965267237799, 'Blond_Hair': -113.52459033628551, 'Brown_Hair': -110.64525748395384, 'Goatee': -110.26379082868624, 'Mustache': -116.10172294652246, 'No_Beard': -106.51167682200729, 'Rosy_Cheeks': -119.35220036279986, 'Wearing_Hat': 374.6323404450723}}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Z-PwWtUG7FvS"
      },
      "source": [
        "# CS Grouping.\n",
        "revised_integrals = {'5_o_Clock_Shadow': {'5_o_Clock_Shadow': 22.00000022440799, 'Black_Hair': 0.25081617948664875, 'Blond_Hair': 0.2061718919411787, 'Brown_Hair': 0.2881746112139813, 'Goatee': 0.625541867401377, 'Mustache': 0.38242295235962176, 'No_Beard': 2.1118374600002627, 'Rosy_Cheeks': 0.1492321557089947, 'Wearing_Hat': 0.29801101447452005}, 'Black_Hair': {'5_o_Clock_Shadow': 0.25081617948664875, 'Black_Hair': 22.00000018551502, 'Blond_Hair': 0.42110319038773486, 'Brown_Hair': 1.369120119153668, 'Goatee': 0.24038168565267032, 'Mustache': 0.15441270535298604, 'No_Beard': 0.3694005647915243, 'Rosy_Cheeks': 0.13147151774684296, 'Wearing_Hat': 0.2561181352237544}, 'Blond_Hair': {'5_o_Clock_Shadow': 0.2061718919411787, 'Black_Hair': 0.42110319038773486, 'Blond_Hair': 22.000000177970726, 'Brown_Hair': 0.5411562459181858, 'Goatee': 0.17836285116075049, 'Mustache': 0.10293958113919395, 'No_Beard': 0.3570151728601598, 'Rosy_Cheeks': 0.2991049753957141, 'Wearing_Hat': 0.1989976827574676}, 'Brown_Hair': {'5_o_Clock_Shadow': 0.2881746112139813, 'Black_Hair': 1.369120119153668, 'Blond_Hair': 0.5411562459181858, 'Brown_Hair': 22.000000190247775, 'Goatee': 0.2258797879312138, 'Mustache': 0.1033378763023164, 'No_Beard': 0.3621806593964234, 'Rosy_Cheeks': 0.13119629044801062, 'Wearing_Hat': 0.16213748738052247}, 'Goatee': {'5_o_Clock_Shadow': 0.625541867401377, 'Black_Hair': 0.24038168565267032, 'Blond_Hair': 0.17836285116075049, 'Brown_Hair': 0.2258797879312138, 'Goatee': 22.00000020739817, 'Mustache': 2.885253067680917, 'No_Beard': 1.957515441464863, 'Rosy_Cheeks': 0.08417127602449773, 'Wearing_Hat': 0.36185996873670195}, 'Mustache': {'5_o_Clock_Shadow': 0.38242295235962176, 'Black_Hair': 0.15441270535298604, 'Blond_Hair': 0.10293958113919395, 'Brown_Hair': 0.1033378763023164, 'Goatee': 2.885253067680917, 'Mustache': 22.00000018340637, 'No_Beard': 0.9277471130069601, 'Rosy_Cheeks': 0.11723857216379202, 'Wearing_Hat': 0.24007541606124794}, 'No_Beard': {'5_o_Clock_Shadow': 2.1118374600002627, 'Black_Hair': 0.3694005647915243, 'Blond_Hair': 0.3570151728601598, 'Brown_Hair': 0.3621806593964234, 'Goatee': 1.957515441464863, 'Mustache': 0.9277471130069601, 'No_Beard': 22.000000198166955, 'Rosy_Cheeks': 0.15052633990539294, 'Wearing_Hat': 0.2759176363075264}, 'Rosy_Cheeks': {'5_o_Clock_Shadow': 0.1492321557089947, 'Black_Hair': 0.13147151774684296, 'Blond_Hair': 0.2991049753957141, 'Brown_Hair': 0.13119629044801062, 'Goatee': 0.08417127602449773, 'Mustache': 0.11723857216379202, 'No_Beard': 0.15052633990539294, 'Rosy_Cheeks': 22.00000021428645, 'Wearing_Hat': 0.07270302767988278}, 'Wearing_Hat': {'5_o_Clock_Shadow': 0.29801101447452005, 'Black_Hair': 0.2561181352237544, 'Blond_Hair': 0.1989976827574676, 'Brown_Hair': 0.16213748738052247, 'Goatee': 0.36185996873670195, 'Mustache': 0.24007541606124794, 'No_Beard': 0.2759176363075264, 'Rosy_Cheeks': 0.07270302767988278, 'Wearing_Hat': 22.000000216582542}}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X3GYdsYvXDec"
      },
      "source": [
        "### Taskonomy"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZdZjTo8bXGrI",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1631771756257,
          "user_tz": 420,
          "elapsed": 57,
          "user": {
            "displayName": "Christopher Fifty",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiYYrGujgM-yL4Ol5R-LyAzmYiPF4BIVzCq-g4=s64",
            "userId": "11483318465600365808"
          }
        }
      },
      "source": [
        "# TAG average inter-task affinities.\n",
        "revised_integrals = {'ss_l': {'ss_l': 11.999321423965947, 'norm_l': 0.04891036220837213, 'depth_l': 0.11537429894122785, 'key_l': 0.00454724706606927, 'edge2d_l': 0.008776140044461892}, 'depth_l': {'ss_l': 0.41355850961858076, 'norm_l': 0.2914147606289435, 'depth_l': 7.3276368896443556, 'key_l': 0.01353652075974951, 'edge2d_l': 0.07395688217092156}, 'norm_l': {'ss_l': 0.13982367687578146, 'norm_l': 1.070989344832557, 'depth_l': 0.22096194178241185, 'key_l': 0.010493772045577332, 'edge2d_l': 0.07645966827511533}, 'key_l': {'ss_l': 0.021098526716001414, 'norm_l': 0.010729013557423, 'depth_l': 0.013365983131803236, 'key_l': 0.7338844180545812, 'edge2d_l': 0.6897358940641894}, 'edge2d_l': {'ss_l': 0.056692022190245474, 'norm_l': 0.03600666458561003, 'depth_l': 0.03668338077695892, 'key_l': 0.30372134710544996, 'edge2d_l': 5.9661909882444615}}"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pMHTyhGpXvCM"
      },
      "source": [
        "# CS Grouping.\n",
        "revised_integrals = {'ss_l': {'ss_l': 9.799051248375491, 'norm_l': 0.24237152135772985, 'depth_l': 0.259569212601681, 'key_l': 0.054060216676064456, 'edge2d_l': 0.08864023100811838}, 'depth_l': {'ss_l': 0.259569212601681, 'norm_l': 0.9115829230192612, 'depth_l': 9.799051266535107, 'key_l': 0.09153054737281137, 'edge2d_l': 0.15646468903944455}, 'norm_l': {'ss_l': 0.24237152135772985, 'norm_l': 9.799051265833779, 'depth_l': 0.9115829230192612, 'key_l': 0.14588496629635408, 'edge2d_l': 0.28013964222462745}, 'key_l': {'ss_l': 0.054060216676064456, 'norm_l': 0.14588496629635408, 'depth_l': 0.09153054737281137, 'key_l': 9.79905053927595, 'edge2d_l': 2.9680815659422954}, 'edge2d_l': {'ss_l': 0.08864023100811838, 'norm_l': 0.28013964222462745, 'depth_l': 0.15646468903944455, 'key_l': 2.9680815659422954, 'edge2d_l': 9.79904991291345}}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RYvGLo5mX5gp"
      },
      "source": [
        "### Network Selection"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DhdVOK1lNqzC",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1631771759228,
          "user_tz": 420,
          "elapsed": 57,
          "user": {
            "displayName": "Christopher Fifty",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiYYrGujgM-yL4Ol5R-LyAzmYiPF4BIVzCq-g4=s64",
            "userId": "11483318465600365808"
          }
        }
      },
      "source": [
        "import math \n",
        "import numpy as np\n",
        "\n",
        "def gen_task_combinations(tasks, rtn, index, path, path_dict):\n",
        "  if index >= len(tasks):\n",
        "    return \n",
        "\n",
        "  for i in range(index, len(tasks)):\n",
        "    cur_task = tasks[i]\n",
        "    new_path = path\n",
        "    new_dict = {k:v for k,v in path_dict.items()}\n",
        "    \n",
        "    # Building from a tree with two or more tasks...\n",
        "    if new_path:\n",
        "      new_dict[cur_task] = 0.\n",
        "      for prev_task in path_dict:\n",
        "        new_dict[prev_task] += revised_integrals[prev_task][cur_task]\n",
        "        new_dict[cur_task] += revised_integrals[cur_task][prev_task]\n",
        "      new_path = '{}|{}'.format(new_path, cur_task)\n",
        "      rtn[new_path] = new_dict\n",
        "    else: # First element in a new-formed tree\n",
        "      new_dict[cur_task] = 0.\n",
        "      new_path = cur_task\n",
        "\n",
        "    gen_task_combinations(tasks, rtn, i+1, new_path, new_dict)\n",
        "\n",
        "\n",
        "    if '|' not in new_path:\n",
        "      new_dict[cur_task] = -1e6 \n",
        "      rtn[new_path] = new_dict"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NuRJiEmxWKMZ"
      },
      "source": [
        "rtn = {}\n",
        "tasks = list(revised_integrals.keys())\n",
        "num_tasks = len(tasks)\n",
        "task_combinations = gen_task_combinations(tasks=tasks, rtn=rtn, index=0, path='', path_dict={})\n",
        "\n",
        "# Normalize by the number of times the accuracy of any given element has been summed. \n",
        "# i.e. (a,b,c) => [acc(a|b) + acc(a|c)]/2 + [acc(b|a) + acc(b|c)]/2 + [acc(c|a) + acc(c|b)]/2\n",
        "for group in rtn:\n",
        "  if '|' in group:\n",
        "    for task in rtn[group]:\n",
        "      rtn[group][task] /= (len(group.split('|')) - 1)\n",
        "\n",
        "print(rtn)\n",
        "assert(len(rtn.keys()) == 2**len(revised_integrals.keys()) - 1)\n",
        "rtn_tup = [(key,val) for key,val in rtn.items()]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WvTgeaR3KWMd"
      },
      "source": [
        "def select_groups(index, cur_group, best_group, best_val, splits):\n",
        "  # Check if this group covers all tasks.\n",
        "  task_set = set()\n",
        "  for group in cur_group:\n",
        "    for task in group.split('|'): task_set.add(task)\n",
        "  if len(task_set) == num_tasks:\n",
        "    best_tasks = {task:-1e6 for task in task_set}\n",
        "    \n",
        "    # Compute the per-task best scores for each task and average them together.\n",
        "    for group in cur_group:\n",
        "      for task in cur_group[group]:\n",
        "        best_tasks[task] = max(best_tasks[task], cur_group[group][task])\n",
        "    group_avg = np.mean(list(best_tasks.values()))\n",
        "    \n",
        "    # Compare with the best grouping seen thus far.\n",
        "    if group_avg > best_val[0]:\n",
        "      print(cur_group)\n",
        "      best_val[0] = group_avg\n",
        "      best_group.clear()\n",
        "      for entry in cur_group:\n",
        "        best_group[entry] = cur_group[entry]\n",
        "  \n",
        "  # Base case.\n",
        "  if len(cur_group.keys()) == splits:\n",
        "    return\n",
        "\n",
        "  # Back to combinatorics \n",
        "  for i in range(index, len(rtn_tup)):\n",
        "    selected_group, selected_dict = rtn_tup[i]\n",
        "\n",
        "    new_group = {k:v for k,v in cur_group.items()}\n",
        "    new_group[selected_group] = selected_dict\n",
        "\n",
        "    if len(new_group.keys()) <= splits:\n",
        "      select_groups(i + 1, new_group, best_group, best_val, splits)\n",
        "\n",
        "selected_group = {}\n",
        "selected_val = [-100000000]\n",
        "select_groups(index=0, cur_group={}, best_group=selected_group, best_val=selected_val, splits=3)\n",
        "print(selected_group)\n",
        "print(selected_val)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}